# -*- coding: utf-8 -*-
"""
Created on Wed Mar 12 17:29:28 2025

@author: jingla
"""
import MultiPyVu as mpv
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.signal import savgol_filter, find_peaks
# Global matplotlib font settings, applied once at import time to every
# figure produced by the functions below.
font = dict(
    family='DejaVu Sans',
    weight='normal',
    size='15',
)
plt.rc('font', **font)

def plot_scope_signal_two_curves(filename1, filename2, title='', xlim=None, dt1=0, dt2=0, V0=0.1):
    """Plot two oscilloscope captures in a three-panel figure.

    Top panel: smoothed input current of both captures. Middle panel: output
    voltage of the first capture. Bottom panel: output voltage of the second
    capture.

    Both files are read with ``np.loadtxt(..., skiprows=1, delimiter=',')`` and
    are unpacked as four columns: index, time (s), output voltage (V) and
    current (mA).

    Parameters
    ----------
    filename1, filename2 : str
        CSV files of the two captures.
    title : str
        Title of the top panel.
    xlim : sequence of two floats or None
        Time window (ms) applied to every panel; ``None`` or empty means
        autoscale.
    dt1, dt2 : float
        Time shift (ms) added to the time axis of each capture.
    V0 : float
        Offset in mV added to both output-voltage traces.
    """
    _, t1_s, Vsd1_V, Isd1_mA = np.loadtxt(filename1, skiprows=1, delimiter=',').T
    _, t2_s, Vsd2_V, Isd2_mA = np.loadtxt(filename2, skiprows=1, delimiter=',').T
    t1_ms = t1_s * 1e3 + dt1
    t2_ms = t2_s * 1e3 + dt2
    has_xlim = xlim is not None and len(xlim) > 0

    plt.figure(figsize=[4, 4])

    # Top panel: input currents of both captures.
    plt.subplot(311)
    plt.plot(t1_ms, savgol_filter(Isd1_mA, 7, 0))
    plt.plot(t2_ms, savgol_filter(Isd2_mA, 7, 0))
    print('I1max={} I2max={}'.format(max(Isd1_mA), max(Isd2_mA)))
    # V0 is in mV (it is added to the mV traces below), so convert it to V
    # before adding it to the raw voltages. The first trace is plotted on a
    # negative axis, hence its extremum is the minimum.
    print('V1max={} V2max={}'.format(min(Vsd1_V + V0 * 1e-3), max(Vsd2_V + V0 * 1e-3)))
    plt.yticks([-2, 0, 2])
    if has_xlim:
        plt.xlim(xlim)
    plt.tick_params(direction='in', top=True, right=True)
    plt.ylabel('$I$ \n(mA)', fontsize=20)
    plt.title(title, fontsize=20)

    # Middle panel: output voltage of the first capture.
    plt.subplot(312)
    plt.plot(t1_ms, Vsd1_V * 1e3 + V0)
    plt.yticks([-2, 0])
    plt.ylim([-2.5, 0.2])
    plt.ylabel('$V_\\mathrm{out}$ \n(mV)', fontsize=20)
    if has_xlim:
        plt.xlim(xlim)
    plt.tick_params(direction='in', top=True, right=True)

    # Bottom panel: output voltage of the second capture.
    plt.subplot(313)
    plt.plot(t2_ms, Vsd2_V * 1e3 + V0, c='C1')
    if has_xlim:
        plt.xlim(xlim)
    plt.ylabel('$V_\\mathrm{out}$ \n(mV)', fontsize=20)
    plt.xlabel('$t$ (ms)', fontsize=20)
    plt.tick_params(direction='in', top=True, right=True)
    plt.yticks([0, 2])
    plt.ylim([-0.2, 2.5])
    return

def plot_IVs_several_B(filename, Blist=None, B0=0, Vran=None, single_trace=False,
                       inds=(0, 150), rev=False):
    """Overlay the IV curves measured at several fields from one MultiVu file.

    The file is parsed with ``mpv.DataFile().parse_MVu_data_file`` and its
    'Field (Oe)', 'Current (A)' and 'Voltage (V)' columns are used.

    Parameters
    ----------
    filename : str
        MultiVu data file.
    Blist : iterable of float or None
        Fields (mT) to plot. A sample is selected when its field, rounded to
        0.1 mT, equals the requested value exactly.
    B0 : float
        Field offset (mT) subtracted in the legend labels.
    Vran : sequence or None
        If given and non-empty, keep only samples with
        ``min(Vran) < V < max(Vran)``.
    single_trace : bool
        If True, keep only the samples ``min(inds):max(inds)`` of each curve.
    inds : sequence of two ints
        Slice bounds used when ``single_trace`` is True.
    rev : bool
        Flip the sign of current and voltage.
    """
    if Blist is None:
        Blist = []
    has_vran = Vran is not None and len(Vran) > 0

    data = mpv.DataFile()
    my_dataframe = data.parse_MVu_data_file(filename)
    field = my_dataframe['Field (Oe)'].to_numpy() / 10  # Oe -> mT
    current = my_dataframe['Current (A)'].to_numpy()
    voltage = my_dataframe['Voltage (V)'].to_numpy()
    if rev:
        current, voltage = -current, -voltage

    field_rounded = np.round(field, 1)  # loop-invariant, compute once
    start, stop = min(inds), max(inds)
    for fu in Blist:
        filt = field_rounded == fu
        Vs, Is = voltage[filt], current[filt] * 1e3  # A -> mA
        if has_vran:
            filtV = (Vs > min(Vran)) & (Vs < max(Vran))
            Vs, Is = Vs[filtV], Is[filtV]
        if single_trace:
            Vs, Is = Vs[start:stop], Is[start:stop]
        plt.plot(Vs, Is, label='B={} mT'.format(np.round(fu, 0) - B0))
    plt.xlabel('$V$ (V)', fontsize=20)
    plt.ylabel('$I$ (mA)', fontsize=20)
    plt.legend()
    plt.grid()
    return

def rectification_calc(B, Icp, Icm, Irp, Irm):
    """Plot the diode rectification efficiency versus field and report its maximum.

    The efficiency is ``eta = (Icp + Icm) / (Icp - Icm)``. As produced by
    ``Ic_and_efficiency_diode_1file_V2`` in this file, ``Icm`` carries its
    negative sign. ``|eta|`` in percent is plotted against ``B`` on the current
    axes. The values at the field(s) of maximal ``|eta|`` are printed.

    Parameters
    ----------
    B : array-like
        Field values (mT when called from ``Ic_and_efficiency_diode_1file_V2``).
    Icp, Icm : array-like
        Positive and negative critical currents (mA from that caller).
    Irp, Irm : array-like
        Positive and negative retrapping currents (mA from that caller).

    Returns
    -------
    (B, rect) : tuple of numpy.ndarray
        The field values and the signed efficiency.
    """
    # asarray: list inputs would otherwise be concatenated by ``+``.
    B = np.asarray(B)
    Icp, Icm = np.asarray(Icp), np.asarray(Icm)
    Irp, Irm = np.asarray(Irp), np.asarray(Irm)

    rect = (Icp + Icm) / (Icp - Icm)
    abs_rect = np.abs(rect)
    plt.plot(B, abs_rect * 100, 'o')
    plt.tick_params(direction='in', top=True, right=True)
    plt.xlabel('$B$ (mT)', fontsize=20)
    plt.ylabel('$|\\eta|$ (%)', fontsize=20)
    # nanmax: fields without a detected Ic give NaN, which makes builtin max unreliable.
    in_rect = abs_rect == np.nanmax(abs_rect)
    print('Max rect={}, Ic+={} mA, Ic-={} mA, Irp={} mA, Irm={} mA, B={} mT'.format(
        rect[in_rect], Icp[in_rect], Icm[in_rect], Irp[in_rect], Irm[in_rect], B[in_rect]))
    return B, rect

def _switching_currents(Vs, Is, rs):
    """Return switching statistics of a single IV trace.

    Parameters
    ----------
    Vs, Is : numpy.ndarray
        Voltage (V) and current (mA) samples of one trace, in acquisition order.
    rs : float
        Fraction of the largest voltage step that a sample's step must exceed
        to be counted as a critical-current event.

    Returns
    -------
    tuple
        ``(Icp_mean, Icp_std, Icm_mean, Icm_std, Iretr_p, Iretr_m)``. Values
        that cannot be determined (empty trace or empty voltage branch) are NaN.
        ``np.mean``/``np.std`` of an empty selection also give NaN, with a
        RuntimeWarning.
    """
    nan = float('nan')
    if Vs.size == 0:
        return nan, nan, nan, nan, nan, nan
    # Step to the previous sample. The first sample is diffed against itself,
    # so the starting voltage is not mistaken for a jump from 0 V.
    dVs = np.diff(Vs, prepend=Vs[:1])
    neg, pos = Vs < 0, Vs > 0
    # Retrapping currents: current at the largest upward step on the V<0
    # branch and at the largest downward step on the V>0 branch.
    Iretr_m = Is[neg][np.argmax(dVs[neg])] if neg.any() else nan
    Iretr_p = Is[pos][np.argmin(dVs[pos])] if pos.any() else nan
    # Critical currents: samples whose step is at least rs times the biggest
    # step in the same direction, on the matching current polarity.
    Icp = Is[(dVs > rs * dVs.max()) & (Is > 0)]
    Icm = Is[(dVs < rs * dVs.min()) & (Is < 0)]
    return np.mean(Icp), np.std(Icp), np.mean(Icm), np.std(Icm), Iretr_p, Iretr_m


def Ic_and_efficiency_diode_1file_V2(filename='NiBiDev1M01_NiBi_Ch2_Bz-10to50Oe_1p7K.dat',
                                     plot_IV_curve=True, B0=0, B_dir='$B_z$ (mT)', Bran=None,
                                     Vran=None, title='', rs=0.6, savefig=False, rev=False,
                                     Ics=False, efficiency=False):
    """Extract critical and retrapping currents versus field from one MultiVu file.

    The file is parsed with ``mpv.DataFile().parse_MVu_data_file`` and its
    'Field (Oe)', 'Current (A)' and 'Voltage (V)' columns are used. Samples
    are grouped by exact field value, in order of first appearance. For each
    group the switching currents are found from voltage jumps (see
    ``_switching_currents``), and Ic+ and |Ic-| are plotted versus field.

    Parameters
    ----------
    filename : str
        MultiVu data file.
    plot_IV_curve : bool
        Show every single-field IV trace in its own figure.
    B0 : float
        Field offset (mT) subtracted on the plot axes.
    B_dir : str
        x-axis label of the field plots.
    Bran : sequence or None
        x-range (mT) of the field plots; ``None`` or empty means autoscale.
    Vran : sequence or None
        If given and non-empty, keep only samples with
        ``min(Vran) < V < max(Vran)``.
    title : str
        Title of the Ic plot (omitted when empty).
    rs : float
        Threshold fraction used for switching detection.
    savefig : bool
        Save the Ic plot as a PDF next to the data file.
    rev : bool
        Flip the sign of current and voltage.
    Ics : bool
        If True, return the extracted currents.
    efficiency : bool
        Also plot the rectification efficiency via ``rectification_calc``.

    Returns
    -------
    tuple or None
        ``(field_un, Icp_list, Icm_list)`` in mT / mA / mA when ``Ics`` is
        True, otherwise None. ``field_un`` is not offset by ``B0``.
    """
    import os  # only needed to build the savefig file name

    data = mpv.DataFile()
    my_dataframe = data.parse_MVu_data_file(filename)
    field = my_dataframe['Field (Oe)'].to_numpy() / 10  # Oe -> mT
    current = my_dataframe['Current (A)'].to_numpy()
    voltage = my_dataframe['Voltage (V)'].to_numpy()
    if rev:
        current, voltage = -current, -voltage
    has_vran = Vran is not None and len(Vran) > 0
    has_bran = Bran is not None and len(Bran) > 0

    # Distinct field values, kept in order of first appearance in the file.
    field_un = field[np.sort(np.unique(field, return_index=True)[1])]
    stats = []
    for fu in field_un:
        filt = field == fu
        Vs, Is = voltage[filt], current[filt] * 1e3  # A -> mA
        if has_vran:
            filtV = (Vs > min(Vran)) & (Vs < max(Vran))
            Vs, Is = Vs[filtV], Is[filtV]
        if plot_IV_curve:
            plt.plot(Vs, Is)
            plt.minorticks_on()
            plt.grid()
            plt.title('B={} mT'.format(np.round(fu, 0)), fontsize=20)
            plt.xlabel('$V$ (V)', fontsize=20)
            plt.ylabel('$I$ (mA)', fontsize=20)
            plt.show()
        stats.append(_switching_currents(Vs, Is, rs))

    # One row per field: Ic+ mean/std, Ic- mean/std, Iretr+, Iretr-.
    stats = np.array(stats, dtype=float).reshape(len(field_un), 6)
    Icp_list, Icpe_list, Icm_list, Icme_list, Iretrp_list, Iretrm_list = stats.T

    B = field_un - B0
    plt.errorbar(B, Icp_list, Icpe_list, fmt='o-', label='$I_c^+$')
    # Ic- is negative: plot its magnitude, with its own spread as the error bar.
    plt.errorbar(B, -Icm_list, Icme_list, fmt='o-', label='$I_c^-$')
    plt.legend()
    if has_bran:
        plt.xlim(min(Bran), max(Bran))
    plt.tick_params(direction='in', top=True, right=True)
    plt.xlabel('{}'.format(B_dir), fontsize=20)
    plt.ylabel('$I_c$ (mA)', fontsize=20)
    if title != '':
        plt.title(title)
    if savefig:
        plt.savefig(os.path.splitext(filename)[0] + '.pdf')

    if efficiency:
        plt.show()
        rectification_calc(B, Icp_list, Icm_list, Iretrp_list, Iretrm_list)
        if has_bran:
            plt.xlim(min(Bran), max(Bran))
        # Same field label as the Ic plot, whether or not a range was given.
        plt.xlabel('{}'.format(B_dir), fontsize=20)
        plt.show()
    if Ics:
        return field_un, Icp_list, Icm_list
    return None

def plot_scope_signal_final(filename, title='', xlim=None, ylim=None, dt=0, tzero=False, l_win=1,
                            peak_sep=20, plot_maxs=True, max_ref=0, max_filt=0, rev=False):
    """Plot one oscilloscope capture and report output-voltage peak statistics.

    The upper panel shows the (smoothed) input current and the lower panel the
    output voltage. The file is read with
    ``np.loadtxt(..., skiprows=1, delimiter=',')`` and unpacked as four
    columns: index, time (s), output voltage (V) and current (mA).

    Parameters
    ----------
    filename : str
        CSV file of the capture.
    title : str
        Title of the upper panel.
    xlim, ylim : sequence of two floats or None
        Axis limits (time in ms, voltage in mV); ``None`` or empty means
        autoscale. When ``xlim`` is set, the upper panel's x tick labels are
        hidden.
    dt : float
        Time shift (ms) added to the time axis.
    tzero : bool
        Shift time so that the first sample is at t=0 (before ``dt``).
    l_win : int
        Savitzky-Golay window length for the voltage (1 means no smoothing).
    peak_sep : int
        Minimum number of samples between detected voltage peaks
        (``distance`` of ``scipy.signal.find_peaks``).
    plot_maxs : bool
        Mark the detected peaks and the low-voltage baseline samples.
    max_ref : float
        Reference height (V); peaks need prominence >= 0.9 * max_ref.
        0 means use the maximum of the trace.
    max_filt : float
        If non-zero, discard peaks whose voltage is >= max_filt (V).
    rev : bool
        Flip the sign of the voltage.
    """
    _, t_s, Vsd_V, Isd_mA = np.loadtxt(filename, skiprows=1, delimiter=',').T
    if rev:
        Vsd_V = -Vsd_V
    if tzero:
        t_s = t_s - t_s[0]
    t_ms = t_s * 1e3 + dt
    has_xlim = xlim is not None and len(xlim) > 0
    has_ylim = ylim is not None and len(ylim) > 0

    # Upper panel, input current
    plt.subplot(211)
    Isd_mA = savgol_filter(Isd_mA, 5, 0)
    plt.plot(t_ms, Isd_mA)
    print('dt={}s'.format(max(t_s) - min(t_s)))
    if min(Isd_mA) > -2 and max(Isd_mA) < 2:
        plt.ylim([-2, 2])
    plt.tick_params(direction='in', top=True, right=True)
    plt.ylabel('$I_\\mathrm{in}$\n(mA)', fontsize=20)
    plt.title(title, fontsize=20)
    if has_xlim:
        plt.xlim(xlim)
        # Hide the x tick labels; the lower panel carries the time axis.
        # tick_params keeps the tick marks without parsing rendered label
        # text, which gives wrong positions with offset/scientific formatters.
        plt.tick_params(axis='x', labelbottom=False)

    # Lower panel, output voltage
    plt.subplot(212)
    Vsd_V = savgol_filter(Vsd_V, l_win, 0)
    plt.plot(t_ms, Vsd_V * 1e3)
    if has_xlim:
        plt.xlim(xlim)
        in_window = (t_ms > min(xlim)) & (t_ms < max(xlim))
    else:
        in_window = np.ones(len(t_s), dtype=bool)
    print('Imax={} mA'.format(max(np.abs(Isd_mA[in_window]))))
    if has_ylim:
        plt.ylim(ylim)
    plt.ylabel('$V_\\mathrm{out}$\n(mV)', fontsize=20)
    plt.xlabel('$t$ (ms)', fontsize=20)
    plt.tick_params(direction='in', top=True, right=True)

    # Peak statistics of the output voltage.
    if max_ref == 0:
        max_ref = max(Vsd_V)
    peaks, _ = find_peaks(Vsd_V, distance=peak_sep, prominence=max_ref * 0.9)
    if max_filt != 0:
        peaks = peaks[Vsd_V[peaks] < max_filt]
    Vmax, Vmaxerr = np.mean(Vsd_V[peaks]), np.std(Vsd_V[peaks])

    # Baseline: samples within the lowest 5% of the voltage range.
    Vmin_ref = min(Vsd_V) + 0.05 * (max(Vsd_V) - min(Vsd_V))
    low = Vsd_V < Vmin_ref
    Vmin, Vminerr = np.mean(Vsd_V[low]), np.std(Vsd_V[low])
    Vampl = Vmax - Vmin
    dV = np.sqrt(Vmaxerr ** 2 + Vminerr ** 2)
    print(f'Number of peaks {len(peaks)}')
    print('Amplitude={} $\\pm$ {} mV'.format(Vampl * 1e3, dV * 1e3))
    print(f'Average Vmax={Vmax*1e3}$\\pm${Vmaxerr*1e3} mV')
    if plot_maxs:
        plt.plot(t_ms[peaks], Vsd_V[peaks] * 1e3, 'x')
        plt.plot(t_ms[low], Vsd_V[low] * 1e3, '-')
    return
